# -*- coding: utf-8 -*-


import math
import numpy as np
from server_placement.server_placement_basic import ServerPlacementBasic

class ServerPlacementUniform(ServerPlacementBasic):

    """
    Assign servers to base stations in proportion to how many stations fall
    into each cell of a regular grid laid over the stations' bounding box.
    """

    # Only stations strictly inside these (low, high) bounds are considered
    # for placement. Instances or subclasses may override them to target a
    # different region; the defaults reproduce the original hard-coded box.
    LATI_RANGE = (30, 32)
    LONG_RANGE = (120, 122)

    def __init__(self, old_lati_list, old_long_list, compute_num):
        """

        param
        ----------
            old_lati_list : list
                the latitude of every base station
            old_long_list : list
                the longitude of every base station, in the same order as
                old_lati_list
            compute_num : int
                the number of servers we want to assign
        """
        # NOTE(review): `_init_` (single underscores) looks like a typo for
        # `__init__`; it only works if ServerPlacementBasic defines a method
        # literally named `_init_`. Confirm against the base class. The base
        # initializer is presumably what sets `self.compute_num`.
        ServerPlacementBasic._init_(self, compute_num)
        self.old_lati_list = old_lati_list
        self.old_long_list = old_long_list

    @staticmethod
    def _cell_index(offset, cell_size, area_length):
        """Return the grid index of a coordinate `offset` above the minimum,
        clamped to [0, area_length - 1]."""
        if cell_size == 0:
            # Every station shares this coordinate: the axis is a single cell.
            return 0
        # Clamp so the maximum coordinate (and any float round-off) lands in
        # the last cell instead of one past it.
        return min(int(math.floor(offset / cell_size)), area_length - 1)

    @staticmethod
    def _apportion(counts, total):
        """Split `total` into integers proportional to `counts` so that the
        parts sum exactly to `total` (largest remainder method)."""
        quotas = counts / counts.sum() * total
        shares = np.floor(quotas).astype(int)
        leftover = int(total - shares.sum())
        if leftover > 0:
            # Hand the remaining units to the cells with the largest
            # fractional parts; a stable sort keeps ties in cell order.
            order = np.argsort(-(quotas - shares), kind="stable")
            shares[order[:leftover]] += 1
        return shares

    def server_placement(self):
        """Assign the servers in proportion to station density on a grid.

        The stations strictly inside LATI_RANGE x LONG_RANGE are binned into
        a k x k grid (k = floor(sqrt(compute_num))) spanning their bounding
        box. Each cell receives a share of compute_num proportional to the
        number of stations it contains; the shares always sum to
        compute_num. A cell's servers are dealt round-robin over its
        stations in index order.

        param
        ----------
            We can get them from the init

        return
        ----------
            final_arrange : list
                final_arrange[j] is the number of servers placed at base
                station j (indexed like old_lati_list); stations outside the
                bounds get 0. All zeros when compute_num < 1.

        raise
        ----------
            ValueError
                if no base station lies inside the bounds
        """
        final_arrange = [0] * len(self.old_lati_list)
        if self.compute_num < 1:
            # Nothing to place (and a 0 x 0 grid would divide by zero).
            return final_arrange

        # Keep only the stations that lie inside the placement bounds.
        lati_lo, lati_hi = self.LATI_RANGE
        long_lo, long_hi = self.LONG_RANGE
        remain_list = [
            j
            for j, (lati, lng) in enumerate(zip(self.old_lati_list, self.old_long_list))
            if lati_lo < lati < lati_hi and long_lo < lng < long_hi
        ]
        if not remain_list:
            raise ValueError("no base station lies inside the placement bounds")

        lati_values = [self.old_lati_list[j] for j in remain_list]
        long_values = [self.old_long_list[j] for j in remain_list]
        min_lati = min(lati_values)
        min_long = min(long_values)

        area_length = int(math.floor(math.sqrt(self.compute_num)))
        lati_length = (max(lati_values) - min_lati) / area_length
        long_length = (max(long_values) - min_long) / area_length

        # cells[c] lists the original station indices in grid cell c
        # (row-major: c = lati_index * area_length + long_index).
        cells = [[] for _ in range(area_length ** 2)]
        for j, lati, lng in zip(remain_list, lati_values, long_values):
            lati_index = self._cell_index(lati - min_lati, lati_length, area_length)
            long_index = self._cell_index(lng - min_long, long_length, area_length)
            cells[lati_index * area_length + long_index].append(j)

        counts = np.array([len(cell) for cell in cells], dtype=float)
        compute_assign = self._apportion(counts, self.compute_num)

        for members, share in zip(cells, compute_assign):
            if not members:
                # Defensive: an empty cell has a zero quota and gets no share.
                continue
            for k in range(int(share)):
                final_arrange[members[k % len(members)]] += 1

        return final_arrange